import torch
from math import pi, cos
from scipy import stats


def fold(data):
    """
    Merge the two leading dimensions of `data` into a single one.

    :param data: [T, B, *]
    :return: [T x B, *]
    :raises IndexError: if `data` has fewer than 3 dimensions.
    """
    if data.ndim < 3:
        raise IndexError("Input data need least 3 dimensions.")
    # Spell out the merged size instead of using -1: reshape cannot infer a
    # -1 dimension when the tensor is empty (e.g. a trailing dimension of 0).
    merged = data.size(0) * data.size(1)
    shape = [merged] + list(data.size()[2:])
    return data.reshape(shape)


def unfold(data, time_dim):
    """
    Split the leading dimension of `data` into a time and a batch dimension.

    :param data: [N, *]
    :param time_dim: Time dimension
    :return: [T, B, *]
    :raises IndexError: if `data` has fewer than 2 dimensions.
    :raises ValueError: if `time_dim` is smaller than 1.
    """
    if data.ndim < 2:
        raise IndexError("Input data need least 2 dimensions.")
    if time_dim < 1:
        raise ValueError("`time_dim` need to be an integer not smaller than 1.")
    total = data.size(0)
    trailing = list(data.size()[1:])
    if total % time_dim == 0:
        # Explicit batch size: reshape cannot infer a -1 dimension when the
        # tensor is empty (e.g. a trailing dimension of 0).
        shape = [time_dim, total // time_dim] + trailing
    else:
        # Not evenly divisible: keep the -1 form so reshape raises its usual
        # shape-mismatch error, exactly as before.
        shape = [time_dim, -1] + trailing
    return data.reshape(shape)


def copy_code(data, num_step):
    """
    Tile the raw input `num_step` times along a new leading dimension.

    The values are copied as they are; no spike encoding is applied.

    :param data: 4-dimensional tensor
    :param num_step: Number of copies stacked along the new first dimension
    :return: Tensor with one extra leading dimension of size `num_step`
    :raises IndexError: if `data` is not 4-dimensional.
    """
    if data.ndim == 4:
        # One repeat count per existing dimension, plus the new leading one.
        repeats = [num_step]
        repeats.extend(1 for _ in range(data.ndim))
        return data.repeat(repeats)
    raise IndexError("Input data's dimension need to be 4.")


def dvs_code(data, num_step=0):
    """
    Swap the first two axes of a 5-dimensional DVS tensor.

    :param data: [B, T, C, H, W]
    :param num_step: Not used by this function.
    :return: [T, B, C, H, W]
    :raises IndexError: if `data` is not 5-dimensional.
    """
    if data.ndim == 5:
        axis_order = (1, 0, 2, 3, 4)
        return data.permute(*axis_order)
    raise IndexError("This method onl for dvs data with shape of [B, T, C, H, W].")


# ==================================== Optimizer Updater ======================================
class WarmUpdate():
    """
    Learning-rate updater: linear warm-up followed by a cosine curve.

    Call `lr_update` once per epoch with increasing `ep`; the warm-up adds to
    the learning rate already stored in the optimizer and the cosine phase
    advances an internal `step`, so the object is stateful.
    """

    def __init__(self, epoch, lr):
        """
        Note: If change lr after initialize, reset k
        :param epoch: Iterate time
        :param lr: Max learning
        :raises ValueError: if `epoch` is smaller than 30.
        """
        if epoch < 30:
            raise ValueError("Can't initialize `WarmUpdate` class, `epoch` need to be at least 30")
        self.lr = lr
        # Number of epochs spent in the linear warm-up.
        self.line_epoch = epoch // 20
        self.epoch = epoch
        # Per-epoch increment that takes the rate from 0.1 * lr up to lr.
        self.k = self.lr * 0.9 / self.line_epoch
        # Current angle of the cosine phase, advanced by `gap` on each call.
        self.step = 0
        self.gap = pi / (self.epoch - self.line_epoch)

    def lr_update(self, ep, optimizer):
        """
        Write the learning rate for epoch `ep` into every parameter group.

        :param ep: Current epoch index, starting from 0
        :param optimizer: Object exposing `param_groups`, a list of dicts
        """
        if ep == 0:
            # First epoch: start at a tenth of the peak rate and overwrite the
            # regularisation settings of groups that already define them.
            for param_group in optimizer.param_groups:
                param_group['lr'] = self.lr * 0.1
                if 'weight_decay' in param_group:
                    param_group['weight_decay'] = 1e-4
                if 'momentum' in param_group:
                    param_group['momentum'] = 0.9
            return
        if ep < self.line_epoch:
            # Linear warm-up: bump whatever rate the group currently holds.
            for param_group in optimizer.param_groups:
                param_group['lr'] += self.k
            return
        # Cosine phase, starting from the peak rate at step 0.
        for param_group in optimizer.param_groups:
            param_group['lr'] = self.lr * 0.5 * (1 + cos(self.step))
        self.step += self.gap


class ImgUpdate():
    """
    Learning-rate updater following half a cosine period, without warm-up.

    Each call to `lr_update` writes the rate for the current internal angle
    and then advances that angle by `pi / epoch`.
    """

    def __init__(self, epoch, lr):
        """
        :param epoch: Number of updates over which the angle sweeps 0..pi
        :param lr: Learning rate used at angle 0
        """
        self.lr, self.epoch = lr, epoch
        self.step = 0
        self.gap = pi / epoch

    def lr_update(self, ep, optimizer):
        """
        Write the current learning rate into every parameter group.

        :param ep: Current epoch index; 0 additionally resets weight decay
            and momentum on groups that already define them
        :param optimizer: Object exposing `param_groups`, a list of dicts
        """
        is_first_epoch = ep == 0
        current_lr = self.lr * 0.5 * (1 + cos(self.step))
        for group in optimizer.param_groups:
            if is_first_epoch:
                if 'weight_decay' in group:
                    group['weight_decay'] = 4e-5
                if 'momentum' in group:
                    group['momentum'] = 0.9
            group['lr'] = current_lr
        self.step += self.gap


# ==================================== K and Gap Method =======================================
def cal_k(mu, sigma, gap, threshold=1.0):
    """
    Normal probability mass in a window of half-width `gap` centred on `threshold`.

    `sigma` is not validated here; the caller is expected to check it.

    :param mu: Passed to `stats.norm.cdf` as the location
    :param sigma: Passed to `stats.norm.cdf` as the scale
    :param gap: Half-width of the window
    :param threshold: Centre of the window
    :return: Mass on [threshold - gap, threshold + gap]
    """
    upper_edge = threshold + gap
    lower_edge = threshold - gap
    mass_below_upper = stats.norm.cdf(upper_edge, mu, sigma)
    mass_below_lower = stats.norm.cdf(lower_edge, mu, sigma)
    return mass_below_upper - mass_below_lower


def cal_half_k(mu, sigma, gap, threshold=1.0):
    """
    Normal probability mass in the one-sided window just above `threshold`.

    :param mu: Passed to `stats.norm.cdf` as the location
    :param sigma: Passed to `stats.norm.cdf` as the scale
    :param gap: Width of the window
    :param threshold: Lower edge of the window
    :return: Mass on [threshold, threshold + gap]
    """
    upper_edge = threshold + gap
    mass_below_upper = stats.norm.cdf(upper_edge, mu, sigma)
    mass_below_threshold = stats.norm.cdf(threshold, mu, sigma)
    return mass_below_upper - mass_below_threshold


def cal_right_k(mu, sigma, gap, threshold=1.0):
    """
    Normal probability mass to the right of `threshold`.

    :param mu: Passed to scipy as the location
    :param sigma: Passed to scipy as the scale
    :param gap: Not used by this function.
    :param threshold: Lower edge of the tail
    :return: Mass on [threshold, oo)
    """
    # The survival function is 1 - cdf computed without the subtraction, so
    # it does not collapse to exactly 0 once cdf rounds to 1 far in the tail.
    return stats.norm.sf(threshold, mu, sigma)


def sub_gap(cal_k_method, mu, sigma, k, last_gap, threshold=1.0, epi=1e-8):
    """
    Shrink `last_gap` until `cal_k_method(mu, sigma, gap, threshold)` is about `k`.

    The gap is only ever decreased: if the value at `last_gap` does not exceed
    `k` by more than `epi`, `last_gap` is returned unchanged.

    All input argument should not be torch.Tensor
    :param cal_k_method: choose from cal_k and cal_half_k
    :param mu: Forwarded to `cal_k_method`
    :param sigma: Forwarded to `cal_k_method`; must not be negative
    :param k: target k
    :param last_gap: current old gap
    :param threshold: Forwarded to `cal_k_method`
    :param epi: Tolerance on both `k` and the final bracket width; 0 is allowed
    :return: The new gap, never larger than `last_gap` for a positive `last_gap`
    :raises ValueError: if `epi` or `sigma` is negative.
    Note: when sigma is too small, result is not accurate, so forbidden.
    """
    if epi < 0:
        raise ValueError("epi need not smaller than 0.")
    if sigma <= epi:
        if sigma < 0:
            raise ValueError("Negative sigma, some wrong")
        print("Sigma is too small, return the old gap.")
        return last_gap
    cur_k = cal_k_method(mu, sigma, last_gap, threshold)
    if cur_k - k <= epi:  # current k is not big enough to decrease
        return last_gap

    # Phase 1: halve the gap until the value drops below the target band,
    # which leaves the target bracketed by [start, end].
    end = last_gap
    start = last_gap / 2
    temp = cal_k_method(mu, sigma, start, threshold)
    while temp >= k - epi:
        if temp <= k + epi:
            return start
        if start == 0:
            # Halving underflowed to zero and the value is still above the
            # band; no smaller gap exists, so stop instead of looping forever.
            return start
        end = start
        start /= 2
        temp = cal_k_method(mu, sigma, start, threshold)

    # Phase 2: bisect the bracket.
    while start + epi < end:
        mid = (start + end) / 2
        if mid <= start or mid >= end:
            # No representable float lies strictly between the bounds (this
            # happens with epi == 0 or epi below the float spacing); the
            # bracket cannot shrink further.
            break
        temp = cal_k_method(mu, sigma, mid, threshold)
        if (temp - k <= epi) and (temp - k >= -epi):
            return mid
        elif temp > k:
            end = mid
        else:
            start = mid
    return (start + end) / 2


def cal_gap(mu, sigma, k, last_gap, threshold=1.0, epi=1e-8):
    """
    Run `sub_gap` with `cal_k` as the measure.

    All arguments are forwarded unchanged.
    :return: Whatever `sub_gap` returns
    """
    new_gap = sub_gap(cal_k, mu, sigma, k, last_gap, threshold=threshold, epi=epi)
    return new_gap


def cal_half_gap(mu, sigma, k, last_gap, threshold=1.0, epi=1e-8):
    """
    Run `sub_gap` with `cal_half_k` as the measure.

    All arguments are forwarded unchanged.
    :return: Whatever `sub_gap` returns
    """
    new_gap = sub_gap(cal_half_k, mu, sigma, k, last_gap, threshold=threshold, epi=epi)
    return new_gap


def back_gap(cal_k_method, mu, sigma, k, last_gap, threshold=1.0, epi=1e-8):
    """
    Move `last_gap` in either direction until `cal_k_method` is about `k`.

    Test sigma in caller
    :param cal_k_method: Callable `(mu, sigma, gap, threshold) -> value`
    :param mu: Forwarded to `cal_k_method`
    :param sigma: Forwarded to `cal_k_method`
    :param k: Target value
    :param last_gap: Gap the search starts from
    :param threshold: Forwarded to `cal_k_method`
    :param epi: Tolerance on both `k` and the final bracket width; 0 is allowed
    :return: The new_gap satisfy cal_k_method(mu, sigma, new_gap, threshold) == k
        (within `epi`). If growing the gap never reaches `k`, `last_gap` is
        returned unchanged.
    :raises ValueError: if `epi` is negative.
    """
    import sys

    if epi < 0:
        raise ValueError("epi need not smaller than 0.")
    cur_k = cal_k_method(mu, sigma, last_gap, threshold)
    if (cur_k - k <= epi) and (cur_k - k >= -epi):
        return last_gap
    elif cur_k > k:
        # Shrink: halve until the value drops below the target band.
        end = last_gap
        start = last_gap / 2
        temp = cal_k_method(mu, sigma, start, threshold)
        while temp >= k - epi:
            if temp <= k + epi:
                return start
            if start == 0:
                # Halving underflowed to zero and the value is still above
                # the band; stop instead of looping forever.
                return start
            end = start
            start /= 2
            temp = cal_k_method(mu, sigma, start, threshold)
    else:  # need increase gap
        if last_gap > 0:
            start = last_gap
            end = last_gap * 2
        else:
            # Doubling a non-positive gap never moves towards larger gaps
            # (0 * 2 == 0), so restart from zero with a positive first probe.
            start = 0.0
            end = sigma if sigma > 0 else 1.0
        temp = cal_k_method(mu, sigma, end, threshold)
        while temp <= k + epi:
            if temp >= k - epi:
                return end
            start = end
            end *= 2
            if end > sys.float_info.max:
                # The target is not reachable by growing the gap; leave the
                # gap unadjusted rather than doubling forever.
                return last_gap
            temp = cal_k_method(mu, sigma, end, threshold)
    # Bisect the bracket [start, end] produced by either branch.
    while start + epi < end:
        mid = (start + end) / 2
        if mid <= start or mid >= end:
            # No representable float lies strictly between the bounds (this
            # happens with epi == 0 or epi below the float spacing).
            break
        temp = cal_k_method(mu, sigma, mid, threshold)
        if (temp - k <= epi) and (temp - k >= -epi):
            return mid
        elif temp > k:
            end = mid
        else:
            start = mid
    return (start + end) / 2


def back_cal_gap(mu, sigma, k, last_gap, threshold=1.0, epi=1e-8):
    """
    Validate the inputs, then run `back_gap` with `cal_k` as the measure.

    :param k: Target value; when `k + epi` reaches 1 the gap is left unchanged
    :param last_gap: Gap the search starts from
    :return: `last_gap` when no adjustment is attempted, otherwise the result
        of `back_gap`
    :raises ValueError: if `epi` or `sigma` is negative.
    """
    if epi < 0:
        raise ValueError("epi need not smaller than 0.")

    if sigma <= epi:
        if sigma >= 0:
            print("Sigma is too small, return the old gap.")
            return last_gap
        raise ValueError("Negative sigma, some wrong")

    target_at_ceiling = k + epi >= 1
    if target_at_ceiling:
        # A small limit_k is in use, so hitting this branch means nobody is
        # trying to raise k; the only possible move would be a decrease, and
        # the current k is already close to the target, so leave the gap alone.
        return last_gap

    return back_gap(cal_k, mu, sigma, k, last_gap, threshold, epi)


def back_cal_half_gap(mu, sigma, k, last_gap, threshold=1.0, epi=1e-8):
    """
    Validate the inputs, then run `back_gap` with `cal_half_k` as the measure.

    :param k: Target value; when `k + epi` reaches the value of `cal_right_k`
        the gap is left unchanged
    :param last_gap: Gap the search starts from
    :return: `last_gap` when no adjustment is attempted, otherwise the result
        of `back_gap`
    :raises ValueError: if `epi` or `sigma` is negative.
    """
    if epi < 0:
        raise ValueError("epi need not smaller than 0.")

    if sigma <= epi:
        if sigma >= 0:
            print("Sigma is too small, return the old gap.")
            return last_gap
        raise ValueError("Negative sigma, some wrong")

    right_tail = cal_right_k(mu, sigma, last_gap, threshold)
    if k + epi >= right_tail:
        # means cur_k in (k, k + epi), so not adjust
        return last_gap

    return back_gap(cal_half_k, mu, sigma, k, last_gap, threshold, epi)
